package org.test4j.mock;

import org.test4j.mock.faking.meta.FakeStates;
import org.test4j.mock.faking.meta.TimesVerify;
import org.test4j.module.Test4JListener;

import java.lang.reflect.Method;

/**
 * Listener that manages mock (fake) state when integrated with the test framework.
 *
 * <p>Three nested save points are kept per thread: one per test class, one per test method and
 * one per test execution. Each {@code after*} callback rolls fakes back to the corresponding save
 * point, so fakes registered inside a scope do not leak past it.
 *
 * @author darui.wu
 */
public class MockListener implements Test4JListener {
    /**
     * Save point recorded at test-class level.
     */
    private static final ThreadLocal<Long> savePointForTestClass = new ThreadLocal<>();
    /**
     * Save point recorded at test-method level.
     */
    private static final ThreadLocal<Long> savePointForTestMethod = new ThreadLocal<>();
    /**
     * Save point recorded around a single test execution.
     */
    private static final ThreadLocal<Long> savePointForTestExecute = new ThreadLocal<>();

    /**
     * Records the class-level save point before any test of {@code testClass} runs.
     *
     * <p>NOTE(review): a second {@code beforeAll} on the same thread without an intervening
     * {@code afterAll} overwrites the previous class save point; confirm this cannot happen
     * (e.g. with nested test classes).
     *
     * @param testClass the test class about to run (unused here)
     */
    @Override
    public void beforeAll(Class testClass) {
        savePointForTestClass.set(FakeStates.getMaxFakeId());
    }

    /**
     * Clears any leftover execution/method-level fakes, records a fresh method-level save point
     * and initializes invocation-count verification.
     *
     * @param target the test instance (unused here)
     */
    @Override
    public void beforeMethod(Object target) {
        rollback(savePointForTestExecute);
        rollbackAndMark(savePointForTestMethod);
        TimesVerify.initVerify();
    }

    /**
     * Rolls back any leftover execution-level fakes and records a fresh execution save point.
     *
     * @param target the test instance (unused here)
     * @param method the test method about to execute (unused here)
     */
    @Override
    public void beforeExecute(Object target, Method method) {
        rollbackAndMark(savePointForTestExecute);
    }

    /**
     * Rolls back fakes registered during the execution, then runs invocation-count verification.
     *
     * @param target the test instance (unused here)
     * @param method the executed test method (unused here)
     * @param e      the failure raised by the test, if any (unused here)
     */
    @Override
    public void afterExecute(Object target, Method method, Throwable e) {
        rollback(savePointForTestExecute);
        TimesVerify.verify();
    }

    /**
     * Rolls back execution- and method-level fakes.
     */
    @Override
    public void afterMethod() {
        rollback(savePointForTestExecute);
        rollback(savePointForTestMethod);
    }

    /**
     * Rolls back execution-, method- and class-level fakes, innermost scope first.
     */
    @Override
    public void afterAll() {
        rollback(savePointForTestExecute);
        rollback(savePointForTestMethod);
        rollback(savePointForTestClass);
    }

    /**
     * Rolls fake state back to the given save point and clears the save point for this thread.
     *
     * <p>The stored value may be {@code null} (no save point recorded); it is passed to
     * {@code FakeStates.rollback} unchanged, as the original code did. Presumably a null save
     * point means "nothing to roll back"; verify in {@code FakeStates}.
     *
     * @param savePoint the per-thread save point holder
     */
    private static void rollback(ThreadLocal<Long> savePoint) {
        FakeStates.rollback(savePoint.get());
        // remove() rather than set(null): drops the thread-map entry instead of keeping a null value
        savePoint.remove();
    }

    /**
     * Rolls back to the given save point and then records a new one for this thread.
     *
     * <p>The new save point is read only AFTER the rollback, so it reflects the restored state
     * rather than fakes that were just discarded.
     *
     * @param savePoint the per-thread save point holder
     */
    private static void rollbackAndMark(ThreadLocal<Long> savePoint) {
        rollback(savePoint);
        savePoint.set(FakeStates.getMaxFakeId());
    }
}